import json
from datetime import datetime
import capstone
from x64dbg import Module, Debug, Memory, Register


# Shared x86-64 disassembler used by get_assembly_at(). Operand detail must be
# enabled: get_mem_value() walks ``instr.operands``, which capstone only
# populates in detail mode.
cap = capstone.Cs(arch=capstone.CS_ARCH_X86, mode=capstone.CS_MODE_64)
cap.detail = True


class SnapShot:
    """Record of the debuggee's state at one traced instruction.

    On construction this reads RIP and the sixteen general-purpose registers,
    disassembles the instruction at RIP and resolves its operand values via
    ``get_mem_value``.
    """

    # Registers captured by __init__, in the order they are stored on the
    # instance (which is also the key order of json()).
    _REGISTERS = (
        "rip",
        "rax", "rbx", "rcx", "rdx", "rbp", "rsp", "rsi", "rdi",
        "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
    )

    def __init__(self, frame: int) -> None:
        self.frame = frame
        for reg in self._REGISTERS:
            enum_member = getattr(Register.RegisterEnum, reg.upper())
            setattr(self, reg, Register.Get(enum_member))
        self.instruction = get_assembly_at(self.rip)
        self.operands = get_mem_value(self.instruction)

    def __str__(self):
        insn = self.instruction
        rendered = ", ".join(format(value, "x") for value in self.operands)
        return f"{self.frame:3} {self.rip:x} {insn.mnemonic} {insn.op_str}: {rendered}"

    def json(self):
        """Return a JSON-serialisable dict of this snapshot.

        All instance attributes are copied; the capstone instruction object is
        replaced by its ``"<mnemonic> <op_str>"`` text.
        """
        record = dict(self.__dict__)
        insn = self.instruction
        record["instruction"] = f"{insn.mnemonic} {insn.op_str}"
        return record


def get_function_address(module, name: str) -> int:
    """Return the address stored in *module*'s import slot for *name*.

    Walks ``Module.GetImports(module)`` and reads the qword at the matching
    import's IAT entry (``iatVa``).

    Raises:
        LookupError: if *module* has no import called *name*.
    """
    for item in Module.GetImports(module):
        if item.name == name:
            return Memory.ReadQword(item.iatVa)
    # Fail loudly here instead of handing None to a caller that expects an address.
    raise LookupError(f"import {name!r} not found in module")
        

def in_module(module, addr: int) -> bool:
    """Return True if *addr* lies inside *module*'s image.

    The image spans the half-open range ``[module.base, module.base + module.size)``;
    ``module.base + module.size`` is the first byte *after* the image.
    """
    return module.base <= addr < module.base + module.size


def in_shellcode(shellcode: int, addr: int, size: int = 0x800000) -> bool:
    """Return True if *addr* lies inside the shellcode buffer.

    The buffer is treated as the half-open range ``[shellcode, shellcode + size)``.
    The default *size* (0x800000) is presumably an upper bound on the buffer
    length; confirm against the actual allocation.
    """
    return shellcode <= addr < shellcode + size


def get_assembly_at(addr: int, cap=cap) -> capstone.CsInsn:
    """Disassemble the single instruction at debuggee address *addr*.

    Args:
        addr: Address to read from in the debuggee.
        cap: Capstone disassembler to use (defaults to the module's x86-64 one).

    Returns:
        The first instruction decoded from the bytes at *addr*.

    Raises:
        ValueError: if the bytes at *addr* do not decode to an instruction.
    """
    # 16 bytes always covers the longest legal x86 instruction (15 bytes).
    raw = Memory.Read(addr, 16)
    # Disassemble at the real address so relative branch/call targets and
    # ``insn.address`` are debuggee addresses, not offsets from 0.
    inst = next(cap.disasm(raw, addr, count=1), None)
    if inst is None:
        raise ValueError(f"cannot disassemble instruction at 0x{addr:x}")
    return inst


def get_reg_value(regname: str) -> int:
    """Return the current value of register *regname*.

    The name is upper-cased and looked up as a member of
    ``Register.RegisterEnum``, so ``"rax"`` and ``"RAX"`` are equivalent.
    """
    member = getattr(Register.RegisterEnum, regname.upper())
    return Register.Get(member)


# Memory reader for each operand width in bytes, as reported by capstone's ``op.size``.
read_func = {}
read_func[1] = Memory.ReadByte
read_func[2] = Memory.ReadWord
read_func[4] = Memory.ReadDword
read_func[8] = Memory.ReadQword

def _effective_address(instr: capstone.CsInsn, op) -> int:
    """Compute the effective address of memory operand *op* of *instr*.

    Computes ``base + index * scale + disp``, treating an absent base or index
    register as 0. *instr* is assumed to be the instruction at the current RIP.
    """
    mem = op.mem
    addr = mem.disp
    if mem.base != capstone.x86.X86_REG_INVALID:
        if mem.base == capstone.x86.X86_REG_RIP:
            # RIP-relative operands are relative to the *next* instruction.
            addr += get_reg_value("rip") + instr.size
        else:
            addr += get_reg_value(instr.reg_name(mem.base))
    if mem.index != capstone.x86.X86_REG_INVALID:
        addr += get_reg_value(instr.reg_name(mem.index)) * mem.scale
    # NOTE(review): fs/gs segment overrides are not applied (segment base not added).
    # Wrap like the CPU's 64-bit address arithmetic.
    return addr & 0xFFFF_FFFF_FFFF_FFFF


def get_mem_value(instr: capstone.CsInsn) -> list[int]:
    """Resolve the current runtime value of every operand of *instr*.

    *instr* must come from a disassembler with detail enabled. It is assumed to
    be the instruction at the current RIP, because register and memory values
    are read from the live debuggee.

    Returns one entry per operand:
      * register  -- the register's value. For ``pop`` it is instead the qword
        at [rsp], i.e. the value about to be popped.
      * memory    -- for ``lea`` the effective address; otherwise the
        little-endian integer of ``op.size`` bytes read there.
      * immediate -- the immediate value.
      * other     -- the raw capstone operand type.
    """
    res = []
    for op in instr.operands:
        if op.type == capstone.x86.X86_OP_REG:
            if instr.mnemonic == "pop":
                res.append(Memory.ReadQword(get_reg_value("rsp")))
            else:
                res.append(get_reg_value(instr.reg_name(op.reg)))
        elif op.type == capstone.x86.X86_OP_MEM:
            memaddr = _effective_address(instr, op)
            if instr.mnemonic == "lea":
                res.append(memaddr)
            elif op.size in read_func:
                res.append(read_func[op.size](memaddr))
            else:
                # Widths without a dedicated reader (e.g. 16-byte SSE operands).
                res.append(int.from_bytes(Memory.Read(memaddr, op.size), "little"))
        elif op.type == capstone.x86.X86_OP_IMM:
            res.append(op.imm)
        else:
            res.append(op.type)
    return res



def _neutralize_cmov(inst: capstone.CsInsn) -> None:
    """Make the register-to-register conditional move *inst* leave its destination unchanged.

    Copies the destination register's value into the source register. After
    that, the move writes the same value back whether or not its condition
    holds. Prints the source register's old and new values.
    """
    target_reg = inst.reg_name(inst.operands[0].reg)
    target_reg_value = get_reg_value(target_reg)
    source_reg = inst.reg_name(inst.operands[1].reg)
    source_reg_value = get_reg_value(source_reg)

    set_source_reg = getattr(Register, f"Set{source_reg.upper()}")
    set_source_reg(target_reg_value)

    new_source_reg_value = get_reg_value(source_reg)
    print(f"Updated {source_reg} from {hex(source_reg_value)} to {hex(new_source_reg_value)}")


def main(output_dir: str = "Z:\\flareon-2024\\9-serpentine") -> None:
    """Trace execution through the main module and the shellcode buffer, then dump it as JSON.

    Every instruction executed inside the main module or the shellcode buffer
    is recorded as a ``SnapShot``. When execution leaves those regions, the
    debugger is resumed until it stops again.

    The trace ends on the first exception from the debugger API or the
    snapshot helpers. It is then written to
    ``<output_dir>\\<timestamp>-all.json``.

    Args:
        output_dir: Directory the JSON trace is written into.
    """
    # set up breakpoints
    main_module = Module.GetMainModuleInfo()
    virtual_alloc_addr = get_function_address(main_module, "VirtualAlloc")
    Debug.SetBreakpoint(virtual_alloc_addr)
    Debug.SetBreakpoint(main_module.entry)
    # NOTE(review): hard-coded main-module address; its purpose is not evident here.
    Debug.DisableBreakpoint(0x1400014f0)

    # Get the shellcode buffer location. Presumably the first stop is inside
    # VirtualAlloc, and stepping out leaves the allocated address in RAX.
    Debug.Run()
    Debug.StepOut()
    shellcode = Register.Get(Register.RegisterEnum.RAX)
    print(f"Shellcode buffer: 0x{shellcode:x}")

    # get breakpoint on return to shellcode from ntdll
    Debug.Run()
    Debug.SetBreakpoint(shellcode + 0x98)
    Debug.Run()
    Debug.Run()  # needed if hlt breaks are on
    rsp = Register.Get(Register.RegisterEnum.RSP)
    print(f"{rsp=:x}")
    # The qword at [rsp] is treated as a return address. Subtracting 2 backs up
    # to the 2-byte `call rax` that pushed it (per the variable name).
    call_rax = Memory.ReadQword(rsp) - 2
    print(f"{call_rax=:x}")
    Debug.SetBreakpoint(call_rax)

    i = 0
    frames = []
    while True:
        try:
            rip = Register.Get(Register.RegisterEnum.RIP)
            if in_module(main_module, rip) or in_shellcode(shellcode, rip):
                frame = SnapShot(i)
                frames.append(frame)
                print(frame)
                inst = frame.instruction
                # NOTE(review): 0x1400011f0 is a hard-coded source value whose
                # cmovne is neutralised; its meaning is not evident here.
                if (
                    inst.mnemonic == "cmovne"
                    and inst.operands[1].type == capstone.x86.X86_OP_REG
                    and frame.operands[1] == 0x1400011f0
                ):
                    _neutralize_cmov(inst)
                Debug.StepIn()
            else:
                # Outside the traced regions: resume until the debugger stops again.
                i += 1
                print(f"frame {i}")
                Debug.Run()
                Debug.StepIn()
        except Exception as exc:
            # Any failure ends the trace. Print the type too: some exceptions
            # (e.g. StopIteration) have an empty message.
            print(f"{type(exc).__name__}: {exc}")
            break

    print(len(frames))
    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    with open(f"{output_dir}\\{timestamp}-all.json", "w", encoding="utf-8") as f:
        json.dump([frame.json() for frame in frames], f)


if __name__ == "__main__":
    main()